import tensorflow as tf

layers = tf.keras.layers

# Initial hidden state: a one-element list holding a [4, 64] zero tensor.
# The cell is called with (and hands back) its state as a list.
h0 = [tf.zeros(shape=[4, 64])]
# Random input of shape [4, 80, 100]; presumably (batch, time steps,
# features) -- the loop further down steps over axis 1.
x = tf.random.normal(shape=[4, 80, 100])
# Slice at index 0 along axis 1, shape [4, 100].
xt = x[:, 0, :]

# One manual step of a 64-unit SimpleRNN cell.
cell = layers.SimpleRNNCell(units=64)
out, h1 = cell(xt, h0)
print(len(h1))
first_state = h1[0]
print(out.shape, first_state.shape)
# Print the object ids of the step output and the first returned state so
# they can be compared by eye.
print(id(out), id(first_state))

# Split x along axis 1 into a Python list of [4, 100] tensors, one per index.
y = tf.unstack(value=x, axis=1)
print(len(y))
print(y[0].shape)

# Run the same cell over every slice in order, starting again from the zero
# state and feeding each returned state into the next call.
h = h0
for xt in y:
    out, h = cell(xt, h)
# Output of the final step.
print(out.shape)
